{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# A Simple Mistral 7B Offline RAG with Reranker"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"../images/simple-rag-1-reranker.webp\" width=700>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1) Load the embedding model and LLM:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2d6dd568bd5942aca2b4b1990dced1c0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "config.json:   0%|          | 0.00/596 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "36d2e512cb794ccd992ea0167d89a4b1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "model.safetensors.index.json:   0%|          | 0.00/25.1k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "510181fb957d4aa8b207e5db6bf2670c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading shards:   0%|          | 0/3 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f271bb19ecbe411193f930b0e911cd3d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "model-00001-of-00003.safetensors:   0%|          | 0.00/4.94G [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7ffda14ee1a342188e0b802c2a04701b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "model-00002-of-00003.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5ae79ed84cc942be877a84f3ad20246b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "model-00003-of-00003.safetensors:   0%|          | 0.00/4.54G [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1634e2eb83ab488084a421fc887f02a8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "17d7ef95ba824187ab1f393fc16640c4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "generation_config.json:   0%|          | 0.00/111 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "706913e2c5664880aec083155d25c65a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "tokenizer_config.json:   0%|          | 0.00/1.46k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d7e48b4b2c454791bba4154c93a28a94",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "tokenizer.model:   0%|          | 0.00/493k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "031abb874a614b01801da528b209f840",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "tokenizer.json:   0%|          | 0.00/1.80M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "af903888c44b4719adea490d3112ff72",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "special_tokens_map.json:   0%|          | 0.00/72.0 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from llama_index.core import Settings, PromptTemplate\n",
    "from llama_index.core.embeddings import resolve_embed_model\n",
    "from llama_index.embeddings.huggingface import HuggingFaceEmbedding\n",
    "from llama_index.llms.huggingface import HuggingFaceLLM\n",
    "\n",
    "import torch\n",
    "\n",
    "# Default embedding model registered on the global Settings object\n",
    "Settings.embed_model = HuggingFaceEmbedding(model_name=\"BAAI/bge-small-en-v1.5\")\n",
    "\n",
    "Settings.llm = HuggingFaceLLM(\n",
    "    context_window=2048,\n",
    "    max_new_tokens=256,\n",
    "    # Greedy decoding (do_sample=False) ignores `temperature`; passing both only\n",
    "    # triggers a transformers UserWarning, so `temperature` is not set here.\n",
    "    generate_kwargs={\"do_sample\": False},\n",
    "    model_name=\"mistralai/Mistral-7B-Instruct-v0.2\",\n",
    "    tokenizer_name=\"mistralai/Mistral-7B-Instruct-v0.2\",\n",
    "    device_map=\"auto\",\n",
    "    model_kwargs={\n",
    "        \"torch_dtype\": torch.bfloat16,\n",
    "        # Since we are using a small GPU with limited memory\n",
    "        # Set to False if you have a large GPU to speed things up.\n",
    "        \"offload_buffers\": True,\n",
    "    },\n",
    ")\n",
    "\n",
    "# Chunk size used when documents are split into nodes\n",
    "Settings.chunk_size = 512"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2) Load data / read in documents"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Read documents: {'Basic-Scientific-Food-Preparation-Lab-Manual.pdf', 'Basic-Scientific-Food-Preparation-Lab-Manual.txt'}\n"
     ]
    }
   ],
   "source": [
    "from pathlib import Path\n",
    "from llama_index.core import VectorStoreIndex, SimpleDirectoryReader\n",
    "\n",
    "# Folder with the example documents, relative to this notebook\n",
    "data_dir = Path(\"..\", \"data\", \"example-1\")\n",
    "\n",
    "reader = SimpleDirectoryReader(data_dir)\n",
    "documents = reader.load_data()\n",
    "\n",
    "# Sanity check: list the distinct source files that were read\n",
    "unique_docs = {doc.metadata[\"file_name\"] for doc in documents}\n",
    "print(f\"Read documents: {unique_docs}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3) Create vector database\n",
    "\n",
    "- `VectorStoreIndex` splits the documents into chunks, embeds them, and keeps the resulting vectors in an in-memory vector database"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b7f6a80b589e4e768aee1d065e5f0eaf",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Parsing nodes:   0%|          | 0/252 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "bc68b7f799e74a568e48e1a844efacd3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Generating embeddings:   0%|          | 0/567 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Database consists of 252 chunks\n"
     ]
    }
   ],
   "source": [
    "index = VectorStoreIndex.from_documents(\n",
    "    documents,\n",
    "    show_progress=True,\n",
    ")\n",
    "\n",
    "# `documents` holds the loaded documents, not the chunks: the chunks (nodes)\n",
    "# created while building the index are kept in the index's docstore, so count\n",
    "# those instead of `len(documents)`.\n",
    "num_chunks = len(index.docstore.docs)\n",
    "print(f\"Database consists of {num_chunks} chunks\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4) Set up custom prompt template"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.core import PromptTemplate\n",
    "\n",
    "# Create the query engine first, then swap in a custom question-answering prompt\n",
    "query_engine = index.as_query_engine()\n",
    "\n",
    "# Prompt layout: retrieved context between separator lines, followed by the question\n",
    "template = (\n",
    "    \"We have provided context information below. \\n\"\n",
    "    \"---------------------\\n\"\n",
    "    \"{context_str}\"\n",
    "    \"\\n---------------------\\n\"\n",
    "    \"Given this information, please answer the question: {query_str}\\n\"\n",
    ")\n",
    "qa_template = PromptTemplate(template)\n",
    "\n",
    "prompt_overrides = {\"response_synthesizer:text_qa_template\": qa_template}\n",
    "query_engine.update_prompts(prompt_overrides)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 5) Query the vector database"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:492: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.25` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`.\n",
      "  warnings.warn(\n",
      "Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Dehydration is a method of preserving food by removing most of the water content. It can be done through various methods such as sun drying, oven drying, or using a dehydrator. Properly dehydrated food can be stored for long periods of time without spoiling and can be rehydrated or used as is for consumption. In this lab, students will be pretreating, drying, storing, and serving some commonly dehydrated fruits and vegetables, and observing the characteristics of various dried fruits.\n"
     ]
    }
   ],
   "source": [
    "question = \"What is dehydration?\"\n",
    "response = query_engine.query(question)\n",
    "print(response)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Answer: 2 tbsp. butter are required for the Cream Puffs.\n"
     ]
    }
   ],
   "source": [
    "# Expected answer: 2\n",
    "question = \"How many tbsp. butter to use for the Cream Puffs?\"\n",
    "response = query_engine.query(question)\n",
    "print(response)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Answer: 2 tbsp. butter are required for the Cream Puffs.\n"
     ]
    }
   ],
   "source": [
    "# NOTE(review): this repeats the previous cell's query verbatim -- presumably a\n",
    "# repeatability check; confirm it is intentional.\n",
    "# Expected answer: 2\n",
    "question = \"How many tbsp. butter to use for the Cream Puffs?\"\n",
    "response = query_engine.query(question)\n",
    "print(response)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "The toppings for the Braided Bread of the Braids, Coffeecake, and Sweet Rolls recipe are:\n",
      "\n",
      "1. 2 tsp. caraway seeds and ½ cup shredded Cheddar cheese.\n",
      "2. ½ cup diced Swiss cheese and paprika.\n"
     ]
    }
   ],
   "source": [
    "# Expected answer:\n",
    "# 2 tsp. caraway seeds and ½ cup shredded Cheddar cheese.\n",
    "# ½ cup diced Swiss cheese and paprika.\n",
    "# ...\n",
    "question = (\n",
    "    \"What are the toppings for the Braided Bread \"\n",
    "    \"of the Braids, Coffeecake, and Sweet Rolls recipe?\"\n",
    ")\n",
    "response = query_engine.query(question)\n",
    "print(response)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
